[PyTorch] Add dtype information to QuantizedTensorStorage class #2676
ptrendx wants to merge 5 commits into NVIDIA:main from
Conversation
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary: This PR adds dtype information to the QuantizedTensorStorage class. Key changes and concerns:
Confidence Score: 2/5
Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["C++ Quantizer::create_tensor(dtype)"] -->|"internal=true"| B["Float8TensorStorage.__new__\n(fake_dtype=GetATenDType(dtype))"]
A -->|"internal=false"| C["Float8Tensor.__new__\n(dtype=GetATenDType(dtype))"]
B -->|"cls is Storage"| D["object.__new__()\ninstance._dtype = fake_dtype"]
B -->|"cls is Tensor subclass"| E["super().__new__(cls, fake_dtype=fake_dtype)"]
C --> E
E --> F["QuantizedTensor.__new__(dtype, fake_dtype)\nValidate: fake_dtype == dtype if not None\ninstance._dtype = dtype"]
F --> G["QuantizedTensor._dtype set"]
D --> G
G -->|"dequantize() called"| H{"dtype arg?"}
H -->|"None"| I["use self._dtype"]
H -->|"explicit"| J["use explicit dtype"]
I --> K["dequantize to correct high-precision dtype"]
J --> K
style D fill:#90EE90
style G fill:#90EE90
style F fill:#FFB6C1,stroke:#FF0000
note1["⚠️ Validation fails for\nmake_like(dtype=X) when X != _dtype"]
F -.-> note1
Last reviewed commit: 369f8b5
/te-ci pytorch
timmoon10
left a comment
Overall this is a big improvement. I have some naming nits.
```python
shape: Iterable[int],
dtype: torch.dtype,
*,
fake_dtype: Optional[torch.dtype] = None,
```
Isn't this redundant with the dtype kwarg?
This is mostly to avoid issues with MRO and still have fairly straightforward constructors for the Storage classes.
Also, I just noticed that the make_like call would otherwise be problematic: we want to include the fake_dtype in the get_metadata call, but if it were named dtype it would clash with the dtype that we pass directly in make_like.
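To illustrate the clash described above, here is a minimal, purely hypothetical sketch (the function names below are stand-ins, not the actual Transformer Engine code): when stored metadata and caller kwargs are both splatted into a constructor, a shared key name raises a TypeError, while a distinct `fake_dtype` key merges cleanly.

```python
def ctor(*, dtype=None, fake_dtype=None):
    """Hypothetical stand-in for a Storage-class constructor."""
    return dtype if dtype is not None else fake_dtype

def make_like(metadata, **kwargs):
    # make_like forwards stored metadata plus caller overrides; a duplicate
    # keyword name would raise TypeError at the call below.
    return ctor(**metadata, **kwargs)

# Distinct key names merge cleanly: explicit dtype plus stored fake_dtype.
result = make_like({"fake_dtype": "bfloat16"}, dtype="float16")

# If the metadata key were also called "dtype", the same pattern would fail:
try:
    make_like({"dtype": "bfloat16"}, dtype="float16")
    clashed = False
except TypeError:  # got multiple values for keyword argument 'dtype'
    clashed = True
```

This is why keeping the metadata key name distinct from the `dtype` kwarg sidesteps the collision entirely.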
```python
data: Optional[torch.Tensor],
fp8_scale_inv: torch.Tensor,
fp8_dtype: TE_DType,
fake_dtype: Optional[torch.dtype] = None,
```
I'd prefer to just name it dtype since QuantizedTensor is already using that name in its constructor.
```diff
-fake_dtype: Optional[torch.dtype] = None,
+dtype: Optional[torch.dtype] = None,
```
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
timmoon10
left a comment
Still not a fan of fake_dtype, but approving to unblock.
/te-ci pytorch
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
/te-ci pytorch
```python
if fake_dtype is not None and fake_dtype != dtype:
    raise ValueError(f"fake_dtype ({fake_dtype}) does not match dtype ({dtype})")
```
Validation breaks existing make_like call sites
This new guard will cause regressions on every call to `make_like(tensor, dtype=X)` where `X` differs from `tensor._dtype`, because `get_metadata()` now always injects `fake_dtype=self._dtype` into kwargs, and `QuantizedTensor.__new__` is then called with both `dtype=X` (the intended new dtype) and `fake_dtype=old_dtype` (from metadata).
Confirmed breakage paths:
- `transformer_engine/pytorch/tensor/__init__.py:63`: the module cast utility (`model.half()`, `model.bfloat16()`, etc.) calls `tensor.__class__.make_like(tensor, dtype=dtype)` for every `QuantizedTensor`; whenever `dtype != tensor._dtype` the model cast will raise `ValueError`.
- `attention/dot_product_attention/context_parallel.py`: 10+ call sites of the form `Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype)` where `fwd_nominal_dtype` may differ from `x._dtype`.
- `attention/dot_product_attention/utils.py:2220`: same pattern.
The root cause is that `fake_dtype` is being included in `get_metadata()` but the constructor-level guard then rejects any case where the caller wants to create a clone at a different nominal dtype. Either:

- Remove the guard (it is redundant for the full-tensor path, because `QuantizedTensor.__new__` already sets `_dtype = dtype`), or
- Override `fake_dtype` in `QuantizedTensor.make_like` so it matches the requested `dtype` before calling the constructor.
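The second option above could be sketched roughly as follows (all names here are hypothetical stand-ins, not the actual `make_like` implementation): overwrite the stale `fake_dtype` in the merged metadata whenever an explicit `dtype` is requested, so the constructor-level guard can never see a mismatch.

```python
def validated_new(*, dtype=None, fake_dtype=None):
    """Stand-in for QuantizedTensor.__new__ with the guard from this diff."""
    if fake_dtype is not None and fake_dtype != dtype:
        raise ValueError(f"fake_dtype ({fake_dtype}) does not match dtype ({dtype})")
    return dtype if dtype is not None else fake_dtype

def make_like(metadata, **kwargs):
    """Hypothetical make_like that keeps the stored fake_dtype in sync
    with an explicitly requested dtype before calling the constructor."""
    merged = dict(metadata)
    merged.update(kwargs)
    if "dtype" in kwargs:
        merged["fake_dtype"] = kwargs["dtype"]  # override stale nominal dtype
    return validated_new(**merged)

# Re-casting a clone to a new nominal dtype no longer trips the guard:
out = make_like({"fake_dtype": "float32"}, dtype="bfloat16")
```

Without the sync step, the same call would raise `ValueError` because the stored `"float32"` would reach the guard alongside the requested `"bfloat16"`.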
```python
_dtype: torch.dtype
_quantizer: Optional[Quantizer]
```
No lazy-init guard for _dtype on storage objects
`QuantizedTensor.dtype` (lines 405–409) has a `hasattr(self, "_dtype")` lazy initializer that protects against deserialization from pre-PR checkpoints. `QuantizedTensorStorage` and its subclasses have no equivalent protection: `_dtype: torch.dtype` is only a class-level annotation, not a default value.
If a `*TensorStorage` object is unpickled from a checkpoint that was saved before this PR, the first call to `.dequantize()` (or the distributed ops in distributed.py that now access `inp._dtype`) will raise `AttributeError: _dtype`.
Consider adding a similar lazy fallback in the dequantize methods, e.g.:
```python
if dtype is None:
    dtype = getattr(self, "_dtype", torch.float32)
```
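A minimal sketch of why this `getattr` fallback matters (the class below is a hypothetical stand-in, and the string `"float32"` stands in for `torch.float32`): an object restored from a pre-PR checkpoint has no `_dtype` attribute, so a bare `self._dtype` access would raise while the fallback degrades gracefully.

```python
import pickle

class StorageSketch:
    """Hypothetical stand-in for a *TensorStorage class (names assumed)."""

    def dequantize(self, dtype=None):
        if dtype is None:
            # Objects restored from pre-PR checkpoints carry no _dtype
            # attribute; fall back to a float32 default instead of
            # raising AttributeError.
            dtype = getattr(self, "_dtype", "float32")
        return dtype

# Simulate unpickling an object saved before _dtype existed:
restored = pickle.loads(pickle.dumps(StorageSketch()))
legacy = restored.dequantize()      # falls back, no AttributeError

new = StorageSketch()
new._dtype = "bfloat16"
current = new.dequantize()          # uses the stored nominal dtype
```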
Description
This PR adds the fake dtype information to the QuantizedTensorStorage class. This eliminates the need to guess the correct dtype for dequantize (as was previously done in distributed.py) and prevents unintentional dequantization to FP32 when dequantize() is called on the Storage class with no dtype argument.
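The before/after behavior described above can be sketched as follows (a hypothetical illustration only, with string dtype names standing in for `torch.dtype` values): the storage now records the nominal high-precision dtype at construction, and `dequantize()` defaults to it instead of FP32.

```python
class StorageWithDtype:
    """Hypothetical sketch of the post-PR storage (names assumed)."""

    def __init__(self, fake_dtype):
        # Nominal high-precision dtype recorded when the tensor is created.
        self._dtype = fake_dtype

    def dequantize(self, dtype=None):
        # Pre-PR, the default here was effectively float32; post-PR the
        # stored nominal dtype is used when no dtype argument is given.
        return dtype if dtype is not None else self._dtype

t = StorageWithDtype("bfloat16")
default = t.dequantize()            # stored dtype, not float32
explicit = t.dequantize("float16")  # an explicit dtype still wins
```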
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: